/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "common_reg.h"
#include "execute_graph_builder.h"
#include "execute_node_builder.h"
#include <gtest/gtest.h>

namespace ge {
// Test fixture for the node-builder unit tests below. It adds no shared state
// or set-up/tear-down; it only groups the test cases under one suite name.
class ExecuteNodeBuilderUt : public testing::Test {
};

// Builds a graph with a single node configured from the registered matmul op
// definition, without setting any attribute explicitly, and verifies the
// resulting op description: name, type, input/output counts, the attribute
// name-to-id mapping, and that both bool attributes read back as false.
TEST_F(ExecuteNodeBuilderUt, BuildNodeOk) {
  ExecuteGraphBuilder gb;
  auto registry = RegOpDefs(gb);
  auto nb = gb.GetNodeBuilder(gb.AddNode());

  nb->SetOpDef(registry.matmul)
      .SetName("HelloMatMul")
      .SetInputName(0, "x1")
      .SetInputName(1, "x2")
      .SetOutputName(0, "y");

  auto graph = gb.Build();
  // Fatal assertions: every pointer checked here is dereferenced right after,
  // so a non-fatal EXPECT_* would let the test crash instead of fail cleanly.
  ASSERT_NE(graph, nullptr);
  EXPECT_EQ(graph->GetNodeCount(), 1);
  auto node = graph->GetNodeByIndex(0);
  ASSERT_NE(node, nullptr);
  auto op_desc = node->GetOpDesc();
  ASSERT_NE(op_desc, nullptr);

  EXPECT_STREQ(op_desc->GetName(), "HelloMatMul");
  EXPECT_STREQ(op_desc->GetType(), "MatMul");
  EXPECT_EQ(op_desc->GetInputDescCount(), 2);
  EXPECT_EQ(op_desc->GetOutputDescCount(), 1);
  EXPECT_EQ(op_desc->GetAttrStore().GetIdByName("transpose_x1"), 0);
  EXPECT_EQ(op_desc->GetAttrStore().GetIdByName("transpose_x2"), 1);

  // Guard the attribute pointers before reading the values through them.
  auto transpose_x1 = op_desc->GetAttrStore().MutableGet<bool>(0);
  ASSERT_NE(transpose_x1, nullptr);
  EXPECT_EQ(*transpose_x1, false);
  auto transpose_x2 = op_desc->GetAttrStore().MutableGet<bool>(1);
  ASSERT_NE(transpose_x2, nullptr);
  EXPECT_EQ(*transpose_x2, false);
}

// Builds a matmul node with attribute 0 explicitly set to true and verifies
// that attribute 0 reads back as true while attribute 1, which is not set
// here, reads back as false.
TEST_F(ExecuteNodeBuilderUt, BuildNodeSetAttrOk1) {
  ExecuteGraphBuilder gb;
  auto registry = RegOpDefs(gb);
  auto nb = gb.GetNodeBuilder(gb.AddNode());

  nb->SetOpDef(registry.matmul)
      .SetName("HelloMatMul")
      .SetInputName(0, "x1")
      .SetInputName(1, "x2")
      .SetOutputName(0, "y")
      .SetAttr(0, true);

  auto graph = gb.Build();
  // Fatal assertions: each pointer is dereferenced immediately afterwards, so
  // a failed build must abort the test instead of crashing the test binary.
  ASSERT_NE(graph, nullptr);
  auto node = graph->GetNodeByIndex(0);
  ASSERT_NE(node, nullptr);
  auto op_desc = node->GetOpDesc();
  ASSERT_NE(op_desc, nullptr);

  auto attr0 = op_desc->GetAttrStore().MutableGet<bool>(0);
  ASSERT_NE(attr0, nullptr);
  EXPECT_EQ(*attr0, true);
  auto attr1 = op_desc->GetAttrStore().MutableGet<bool>(1);
  ASSERT_NE(attr1, nullptr);
  EXPECT_EQ(*attr1, false);
}

// Builds a conv2d node with attributes 0 and 1 set explicitly and verifies
// that those two read back as set, that attributes 2..5 read back with the
// values expected below although this test never sets them, that attribute
// id 6 does not exist, and that the attribute names map to ids 0..5.
TEST_F(ExecuteNodeBuilderUt, BuildNodeDefaultAttrOk2) {
  ExecuteGraphBuilder gb;
  auto registry = RegOpDefs(gb);
  auto nb = gb.GetNodeBuilder(gb.AddNode());

  nb->SetOpDef(registry.conv2d)
      .SetName("Hello_conv2d")
      .SetInputName(0, "x")
      .SetInputName(1, "filter")
      .SetOutputName(0, "y")
      .SetAttr(0, 1)
      .SetAttr(1, 10);

  auto graph = gb.Build();
  // Fatal assertions: each pointer is dereferenced immediately afterwards, so
  // a failed build must abort the test instead of crashing the test binary.
  ASSERT_NE(graph, nullptr);
  auto node = graph->GetNodeByIndex(0);
  ASSERT_NE(node, nullptr);
  auto op_desc = node->GetOpDesc();
  ASSERT_NE(op_desc, nullptr);

  auto &&attrs = op_desc->GetAttrStore();

  // Explicitly set attributes.
  auto strides = attrs.MutableGet<int>(0);
  ASSERT_NE(strides, nullptr);
  EXPECT_EQ(*strides, 1);
  auto pads = attrs.MutableGet<int>(1);
  ASSERT_NE(pads, nullptr);
  EXPECT_EQ(*pads, 10);

  // Attributes not set by this test.
  auto dilations = attrs.MutableGet<std::vector<int64_t>>(2);
  ASSERT_NE(dilations, nullptr);
  EXPECT_EQ(*dilations, std::vector<int64_t>({1, 1, 1, 1}));
  auto groups = attrs.MutableGet<int>(3);
  ASSERT_NE(groups, nullptr);
  EXPECT_EQ(*groups, 1);
  auto data_format = attrs.MutableGet<std::string>(4);
  ASSERT_NE(data_format, nullptr);
  EXPECT_EQ(*data_format, std::string("NHWC"));
  auto offset_x = attrs.MutableGet<int>(5);
  ASSERT_NE(offset_x, nullptr);
  EXPECT_EQ(*offset_x, 0);

  // One past the last attribute id checked above: lookup must yield nullptr.
  EXPECT_EQ(attrs.MutableGet<int>(6), nullptr);

  EXPECT_EQ(attrs.GetIdByName("strides"), 0);
  EXPECT_EQ(attrs.GetIdByName("pads"), 1);
  EXPECT_EQ(attrs.GetIdByName("dilations"), 2);
  EXPECT_EQ(attrs.GetIdByName("groups"), 3);
  EXPECT_EQ(attrs.GetIdByName("data_format"), 4);
  EXPECT_EQ(attrs.GetIdByName("offset_x"), 5);
}

// Configures a conv2d node with only attribute 0 set and expects the graph
// build to fail (Build() yields nullptr).
// NOTE(review): presumably attribute 1 is required and has no default, which
// is what makes the build fail - confirm against the conv2d op definition.
TEST_F(ExecuteNodeBuilderUt, BuildNodeRequiredAttrNotExsits) {
  ExecuteGraphBuilder graph_builder;
  auto op_defs = RegOpDefs(graph_builder);
  auto node_builder = graph_builder.GetNodeBuilder(graph_builder.AddNode());

  node_builder->SetOpDef(op_defs.conv2d)
      .SetName("Hello_conv2d")
      .SetInputName(0, "x")
      .SetInputName(1, "filter")
      .SetOutputName(0, "y")
      .SetAttr(0, 1);  // attribute 1 is deliberately left unset

  auto built_graph = graph_builder.Build();
  EXPECT_EQ(built_graph, nullptr);
}

// Configures a matmul node and sets attribute id 2, then expects the graph
// build to fail (Build() yields nullptr).
// NOTE(review): the other matmul tests in this file only use attribute ids 0
// and 1, so id 2 is presumably out of range - confirm against the op definition.
TEST_F(ExecuteNodeBuilderUt, BuildNodeInvalidAttrSpecified) {
  ExecuteGraphBuilder graph_builder;
  auto op_defs = RegOpDefs(graph_builder);
  auto node_builder = graph_builder.GetNodeBuilder(graph_builder.AddNode());

  node_builder->SetOpDef(op_defs.matmul)
      .SetName("HelloMatMul")
      .SetInputName(0, "x1")
      .SetInputName(1, "x2")
      .SetOutputName(0, "y")
      .SetAttr(2, 10);  // the attribute id under test

  auto built_graph = graph_builder.Build();
  EXPECT_EQ(built_graph, nullptr);
}
}  // namespace ge